"""
OpenGauss 数据库模式编辑器
处理数据库模式的创建、修改和删除操作
"""
from django.core.exceptions import FieldDoesNotExist
from django.db.backends.postgresql import schema as postgresql_schema
from django.db.backends.ddl_references import Statement
from django.db import models


class DatabaseSchemaEditor(postgresql_schema.DatabaseSchemaEditor):
    """
    Schema editor for OpenGauss.

    Inherits from the PostgreSQL schema editor and adjusts the pieces of DDL
    that OpenGauss handles differently (auto-increment columns, JSON defaults,
    deferred foreign keys, index creation order).
    """

    # Sequence management SQL templates.
    sql_create_sequence = "CREATE SEQUENCE %(sequence)s"
    sql_delete_sequence = "DROP SEQUENCE IF EXISTS %(sequence)s CASCADE"
    sql_set_sequence_max = "SELECT setval('%(sequence)s', MAX(%(column)s)) FROM %(table)s"
    # NOTE(review): Django interpolates sql_alter_column_type into
    # sql_alter_column ("ALTER TABLE %(table)s %(changes)s") without a "table"
    # key, so this full-statement form would fail if used as-is.  Django's
    # PostgreSQL editor reassigns this attribute inside _alter_column_type_sql;
    # confirm that still holds for the Django version in use.
    sql_alter_column_type = "ALTER TABLE %(table)s ALTER COLUMN %(column)s TYPE %(type)s USING %(column)s::%(type)s"
    # NOTE(review): Django's base editor reads sql_alter_column_null /
    # sql_alter_column_not_null.  These two names do not appear to be read by
    # Django itself; confirm before relying on them.
    sql_alter_column_drop_not_null = "ALTER TABLE %(table)s ALTER COLUMN %(column)s DROP NOT NULL"
    sql_alter_column_set_not_null = "ALTER TABLE %(table)s ALTER COLUMN %(column)s SET NOT NULL"

    # sql_create_index_concurrently and sql_create_unique_index are deliberately
    # inherited from the PostgreSQL editor.  Simplified local copies would
    # either drop the %(include)s / %(condition)s placeholders, turning
    # partial or covering indexes into plain ones, or reference an %(extra)s
    # part that _create_unique_sql() never supplies (KeyError for conditional
    # UniqueConstraints).

    def __init__(self, connection, collect_sql=False, atomic=True):
        super().__init__(connection, collect_sql, atomic)
        # OpenGauss uses a different collation-change syntax, so the inherited
        # template is disabled.
        # NOTE(review): if a migration changes a column's collation, the parent
        # editor will %-format this attribute and fail with TypeError.  Consider
        # overriding _alter_column_collation_sql with OpenGauss-specific SQL or
        # an explicit NotSupportedError.
        self.sql_alter_column_collate = None

    def _field_data_type(self, field):
        """
        Return the database type for ``field``, with OpenGauss-specific mapping.

        * ``JSONField`` maps to ``JSONB``.
        * A ``TextField`` with ``db_collation`` gets a ``COLLATE`` clause, with
          the collation name quoted through ``quote_name`` so that pre-quoted
          names are not double-quoted.
        """
        data_type = super()._field_data_type(field)

        if isinstance(field, models.JSONField):
            # JSONB is preferred for query performance.
            return 'JSONB'
        collation = getattr(field, 'db_collation', None)
        if isinstance(field, models.TextField) and collation:
            return f'{data_type} COLLATE {self.quote_name(collation)}'
        return data_type

    def column_sql(self, model, field, include_default=False):
        """
        Return ``(definition, params)`` for ``field``'s column definition.

        Auto-increment fields are emitted as ``serial`` pseudo-types instead
        of ``GENERATED BY DEFAULT AS IDENTITY``.  Fields without a database
        column (the parent returns ``(None, None)``, e.g. M2M fields) are
        passed through unchanged so that callers can skip them.
        """
        # BigAutoField / SmallAutoField are checked before AutoField because
        # isinstance(x, AutoField) is also true for them.
        for auto_class, serial_type in (
            (models.BigAutoField, 'bigserial'),
            (models.SmallAutoField, 'smallserial'),
            (models.AutoField, 'serial'),
        ):
            if isinstance(field, auto_class):
                suffix = ' PRIMARY KEY' if field.primary_key else ''
                return serial_type + suffix, []

        column_result = super().column_sql(model, field, include_default)
        if isinstance(column_result, tuple):
            definition, params = column_result
        else:
            definition, params = column_result, []

        # No column for this field (e.g. ManyToManyField): nothing to rewrite.
        if definition is None:
            return definition, params

        if 'GENERATED BY DEFAULT' in definition:
            definition = self._identity_to_serial(definition)
        return definition, params

    @staticmethod
    def _identity_to_serial(definition):
        """
        Rewrite an identity column definition to use a serial pseudo-type.

        The definition begins with the column type (the column name is not
        part of it), e.g. ``"bigint GENERATED BY DEFAULT AS IDENTITY NOT NULL"``
        becomes ``"bigserial NOT NULL"``.  Definitions whose leading type has
        no serial equivalent are returned unchanged.
        """
        serial_for = {'bigint': 'bigserial', 'integer': 'serial', 'smallint': 'smallserial'}
        stripped = definition.replace(' GENERATED BY DEFAULT AS IDENTITY', '')
        stripped = stripped.replace(' GENERATED BY DEFAULT', '')
        base_type, sep, rest = stripped.partition(' ')
        serial_type = serial_for.get(base_type.lower())
        if serial_type is None:
            return definition
        return serial_type + sep + rest

    def add_field(self, model, field):
        """
        Add ``field``'s column to ``model``'s table.

        JSONField defaults are deliberately not pre-adapted here.  The parent
        computes the column default via ``effective_default()``, which already
        runs it through ``JSONField.get_db_prep_save()`` and hence
        ``connection.ops.adapt_json_value()``.  Adapting ``field.default`` here
        as well would double-wrap the value, and assigning it would mutate the
        shared field instance (changing the model's Python-level default).
        """
        super().add_field(model, field)

    def _create_fk_sql(self, model, field, suffix):
        """
        Return the SQL that creates the foreign-key constraint for ``field``.

        When the field carries ``db_constraint_defer = True``, the constraint
        is made ``DEFERRABLE INITIALLY DEFERRED``.  The clause is not appended
        if the rendered SQL already contains a DEFERRABLE clause, because
        duplicate clauses are a syntax error.
        """
        sql = super()._create_fk_sql(model, field, suffix)

        if getattr(field, 'db_constraint_defer', False) and 'DEFERRABLE' not in str(sql).upper():
            if isinstance(sql, Statement):
                sql = Statement(
                    sql.template + " DEFERRABLE INITIALLY DEFERRED",
                    **sql.parts,
                )
            elif isinstance(sql, str):
                sql += " DEFERRABLE INITIALLY DEFERRED"

        return sql

    def prepare_default(self, value):
        """
        Return ``value`` as a SQL literal suitable for a column DEFAULT.

        ``dict`` / ``list`` values are adapted as JSON and then quoted.  All
        other values are delegated to the parent implementation.
        """
        if isinstance(value, (dict, list)):
            return self.quote_value(self.connection.ops.adapt_json_value(value, None))
        return super().prepare_default(value)

    @staticmethod
    def _index_targets_unique_fields(model, index):
        """Return True when every entry of ``index.fields`` names a unique model field."""
        if not index.fields:
            return False
        try:
            return all(
                isinstance(name, str)
                and model._meta.get_field(name.lstrip('-')).unique
                for name in index.fields
            )
        except FieldDoesNotExist:
            return False

    def _model_indexes_sql(self, model):
        """
        Return the index-creation statements for ``model``.

        The method mirrors Django's base behaviour:
        * unmanaged, proxy and swapped models get no indexes;
        * field-level indexes (``db_index`` / PostgreSQL ``_like``) come first;
        * then ``index_together`` indexes, where that option still exists;
        * then ``Meta.indexes``, skipping expression indexes when the backend
          does not support them.

        The one OpenGauss-specific tweak is ordering: within ``Meta.indexes``,
        indexes whose fields are all unique are emitted after the others.
        """
        meta = model._meta
        if not meta.managed or meta.proxy or meta.swapped:
            return []

        output = []
        for field in meta.local_fields:
            output.extend(self._field_indexes_sql(model, field))

        # index_together was removed in Django 5.1; getattr keeps this
        # working on both sides of that change.
        for field_names in getattr(meta, 'index_together', ()):
            fields = [meta.get_field(name) for name in field_names]
            output.append(self._create_index_sql(model, fields=fields, suffix='_idx'))

        supports_expressions = getattr(
            self.connection.features, 'supports_expression_indexes', True
        )
        normal_indexes = []
        unique_indexes = []
        for index in meta.indexes:
            if getattr(index, 'contains_expressions', False) and not supports_expressions:
                continue
            if self._index_targets_unique_fields(model, index):
                unique_indexes.append(index)
            else:
                normal_indexes.append(index)

        # Plain indexes first, then those over unique fields.
        output.extend(index.create_sql(model, self) for index in normal_indexes)
        output.extend(index.create_sql(model, self) for index in unique_indexes)
        return output

    def execute(self, sql, params=()):
        """
        Execute ``sql`` after rewriting identity syntax that OpenGauss rejects.

        For plain-string CREATE TABLE / ALTER TABLE statements, the clause
        `` GENERATED BY DEFAULT AS IDENTITY`` (and any remaining
        `` GENERATED BY DEFAULT``) is removed.  Auto-increment columns are
        emitted as serial types by ``column_sql`` instead.  Non-string
        statements are passed through untouched.

        In collect-SQL mode, errors mentioning "already exists" are ignored.
        All other errors propagate.

        NOTE(review): when a column is altered into an AutoField, Django's
        PostgreSQL editor emits
        ``ALTER TABLE ... ALTER COLUMN ... ADD GENERATED BY DEFAULT AS IDENTITY``.
        Stripping that clause leaves a dangling ``ADD``, so the statement will
        fail.  That migration path needs a sequence-based replacement.
        """
        if (
            isinstance(sql, str)
            and ('CREATE TABLE' in sql or 'ALTER TABLE' in sql)
            and 'GENERATED BY DEFAULT' in sql
        ):
            sql = sql.replace(' GENERATED BY DEFAULT AS IDENTITY', '')
            sql = sql.replace(' GENERATED BY DEFAULT', '')

        try:
            super().execute(sql, params)
        except Exception as exc:
            if not (self.collect_sql and "already exists" in str(exc)):
                raise
